import numpy as np
import os
from nats_bench import create
import random

from exps.genotypes import Structure

# Labels for the cell's input and output nodes in the DAG encoding.
INPUT, OUTPUT = 'input', 'output'

# Candidate edge operations. The position of an op in this list determines its
# integer / one-hot code in Natsbench.dag_encoding.
OPS = [
    'avg_pool_3x3',
    'nor_conv_1x1',
    'nor_conv_3x3',
    'none',
    'skip_connect',
]
NUM_OPS = len(OPS)

# Presumably the number of edges (operation slots) in a cell -- TODO confirm.
OP_SPOTS = 6
# Presumably the number of edges on the longest input->output path -- TODO confirm.
LONGEST_PATH_LENGTH = 3


class Natsbench(object):
    """Helper around the NATS-Bench API for sampling, encoding and scoring architectures.

    Provides random samplers for cell topologies and channel configurations,
    a fixed 8-node DAG encoding of topology strings (with simulated 12-epoch
    statistics attached), and lookups of accuracies after the full training
    schedule.
    """

    def __init__(self, data_path, search_space):
        """Create the NATS-Bench API for ``search_space`` (``'tss'`` is special-cased below).

        NOTE(review): ``data_path`` is accepted but never used -- the API is
        created with ``create(None, ...)``. Kept for interface compatibility;
        confirm whether it was meant to be forwarded to ``create``.
        """
        self.search_space = search_space
        self.api = create(None, search_space, fast_mode=True, verbose=False)
        # Index of each cell edge "target<-source" (node 0 is the cell input).
        self.edge2index = {'1<-0': 0, '2<-0': 1, '2<-1': 2, '3<-0': 3, '3<-1': 4, '3<-2': 5}
        # Candidate edge operations. Note: a different order from the
        # module-level OPS, which is what dag_encoding uses for its codes.
        self.op_names = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
        self.max_nodes = 4

    def random_topology_func(self, op_names, max_nodes=4):
        """Return a zero-argument sampler of random cell topologies.

        Each call of the returned function draws one op from ``op_names``
        (via ``random.choice``) for every edge ``i<-j`` with
        ``0 <= j < i < max_nodes`` and wraps the per-node edge tuples in a
        ``Structure``.
        """
        def random_architecture():
            genotypes = []
            for i in range(1, max_nodes):
                # Edges entering node i: one (op_name, source_node) pair per predecessor j.
                xlist = [(random.choice(op_names), j) for j in range(i)]
                genotypes.append(tuple(xlist))
            return Structure(genotypes)

        return random_architecture

    def random_size_func(self, info):
        """Return a zero-argument sampler of random channel configurations.

        ``info["numbers"]`` is how many values to draw and ``info["candidates"]``
        the values to draw from (via ``random.choice``). Each call returns the
        sampled values joined by ``":"``.
        """
        def random_architecture():
            return ":".join(str(random.choice(info["candidates"])) for _ in range(info["numbers"]))

        return random_architecture

    def dag_encoding(self, arch, deterministic=True, dataset='cifar10', st='tss'):
        """Encode a topology architecture string as an 8-node operation-on-node DAG.

        Node 0 is the cell input, nodes 1-6 are the six edge operations returned
        by ``get_op_list`` (in that order) and node 7 is the cell output.

        Args:
            arch: architecture string, parsed by ``get_op_list`` and also passed
                to ``get_simul_train_epoch12_info``.
            deterministic: forwarded to ``get_simul_train_epoch12_info``
                (which currently ignores it).
            dataset: dataset name used for the accuracy/time lookup.
            st: unused; kept for interface compatibility.

        Returns:
            dict with keys ``num_vertices`` (8), ``adjacency`` (8x8 int array),
            ``operations_oneshot`` (8 x (len(OPS) + 2) float32 one-hot rows over
            ``[OUTPUT, INPUT, *OPS]``), ``operations`` (8 ints: 0 = input,
            OPS index + 1 for edge ops, 6 = output), ``mask`` (8 float32 ones),
            ``val_acc``, ``test_acc`` and ``time_cost``.

        Raises:
            ValueError: if an op parsed from ``arch`` is not in ``OPS``.
        """
        op_map = [OUTPUT, INPUT, *OPS]
        edge_ops = self.get_op_list(arch)
        ops_idx = [OPS.index(op) for op in edge_ops]

        # One-hot row per node over op_map, for nodes [INPUT, *edge_ops, OUTPUT].
        nodes = [INPUT, *edge_ops, OUTPUT]
        ops_onehot = np.array(
            [[i == op_map.index(op) for i in range(len(op_map))] for op in nodes],
            dtype=np.float32,
        )

        # Integer code per node: 0 = input, OPS index + 1 for edge ops, 6 = output.
        # This numbering is independent of the op_map order used for ops_onehot.
        ops = [0, *(idx + 1 for idx in ops_idx), 6]

        # NOTE(review): assuming edge_ops follow the edge2index order
        # (1<-0, 2<-0, 2<-1, 3<-0, 3<-1, 3<-2), this matrix wires node 3 like
        # edge 3<-0 (input -> node 3 -> output) and node 4 like edge 2<-1
        # (node 1 -> node 4 -> node 6), i.e. it matches the op order with
        # ops[2] and ops[3] swapped. get_op_list deliberately does not apply
        # that swap (it was reported to perform worse), so nodes 3 and 4 are
        # not wired per the true cell topology. Kept to preserve behaviour.
        matrix = np.array(
            [[0, 1, 1, 1, 0, 0, 0, 0],
             [0, 0, 0, 0, 1, 1, 0, 0],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 0]])

        val_acc, test_acc, time_cost = self.get_simul_train_epoch12_info(dataset, arch, deterministic=deterministic)
        return {
            'num_vertices': 8,
            'adjacency': matrix,
            # Key spelling kept for backward compatibility (presumably meant "onehot").
            'operations_oneshot': ops_onehot,
            'operations': ops,
            'mask': np.ones(8, dtype=np.float32),
            'val_acc': val_acc,
            'test_acc': test_acc,
            'time_cost': time_cost,
        }

    def get_simul_train_epoch12_info(self, dataset, arch, deterministic=True):
        """Return ``(val_acc, test_acc, time_cost)`` from a simulated 12-epoch run.

        Calls ``api.simulate_train_eval(arch, dataset, hp="12")`` so that the
        returned ``time_cost`` accounts for the (simulated) training time.
        ``test_acc`` is always ``0.0``: architectures are selected using the
        12-epoch validation accuracy, and accurate val/test accuracies are then
        obtained from the 90- or 200-epoch results (``get_simul_full_train_info``).

        ``deterministic`` is currently ignored.
        """
        val_acc, _, time_cost, _ = self.api.simulate_train_eval(arch, dataset, hp="12")
        test_acc = 0.0
        return val_acc, test_acc, time_cost

    def get_simul_full_train_info(self, dataset, arch, deterministic=True):
        """Return ``(valid_acc, test_acc)`` after the full training schedule.

        The schedule is ``hp="200"`` when ``self.search_space == "tss"`` and
        ``hp="90"`` otherwise; lookups use ``is_random=False``. For
        ``"cifar10"`` the test accuracy is read from ``"cifar10"`` and the
        validation accuracy from ``"cifar10-valid"``; for any other dataset both
        come from a single lookup. ``deterministic`` is currently ignored.
        """
        hp = "200" if self.search_space == "tss" else "90"

        if dataset == "cifar10":
            xinfo = self.api.get_more_info(arch, dataset=dataset, hp=hp, is_random=False)
            test_acc = xinfo["test-accuracy"]
            xinfo = self.api.get_more_info(arch, dataset="cifar10-valid", hp=hp, is_random=False)
            valid_acc = xinfo["valid-accuracy"]
        else:
            xinfo = self.api.get_more_info(arch, dataset=dataset, hp=hp, is_random=False)
            valid_acc = xinfo["valid-accuracy"]
            test_acc = xinfo["test-accuracy"]

        return valid_acc, test_acc

    def get_op_list(self, string):
        """Return the six edge operation names of a topology architecture string.

        The string is split on ``'|'``; tokens at positions 0, 2, 5 and 9 (the
        empty ends and the two ``'+'`` separators of a well-formed
        ``|op~0|+|op~0|op~1|+|op~0|op~1|op~2|`` string) are dropped, and each
        remaining ``op~j`` token is reduced to ``op``.
        """
        # Swapping ops[2] and ops[3] would appear to match the adjacency matrix
        # in dag_encoding, but was found to work worse than not swapping, so the
        # swap is intentionally not applied.
        skipped = {0, 2, 5, 9}
        return [token.split('~')[0] for pos, token in enumerate(string.split('|')) if pos not in skipped]

